# Causal Language Model Training for Demographic Bias Mitigation

This repository contains the implementation for our ICLR submission on causal language model training methods to mitigate demographic bias. The code implements four training approaches: Supervised Fine-Tuning (SFT), Causal PPO, Simple PPO, and Adversarial PPO.

## 🚀 Quick Start

### 1. Environment Setup

First, create a Python virtual environment and install the required dependencies:

```bash
# Create virtual environment
python -m venv .venv

# Activate the environment
source .venv/bin/activate  # On Linux/Mac
# or
.venv\Scripts\activate  # On Windows

# Install requirements
pip install -r requirements.txt
```

### 2. Dataset Setup

#### 2.1 Download DiscrimEval Dataset

Download the DiscrimEval dataset from Hugging Face using Git LFS:

```bash
# Install git-lfs if not already installed
sudo apt-get install git-lfs  # On Ubuntu/Debian

# Initialize git-lfs
git lfs install

# Clone the DiscrimEval dataset
git clone https://huggingface.co/datasets/Anthropic/discrim-eval

# Create the dataset directory and copy the files into it
mkdir -p dataset
cp -r discrim-eval/* dataset/
```
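To confirm that the DiscrimEval copy is usable, a quick Python check can count the prompts in each demographic group. This is a sketch and not part of the repository. It assumes that the files are JSON Lines with `race`, `gender`, and `age` columns and that `explicit.jsonl` sits directly in `dataset/`, as described on the DiscrimEval dataset card. Adjust the path if your layout differs.

```python
import json
from collections import Counter
from pathlib import Path


def load_jsonl(path):
    """Read a JSON Lines file into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def group_counts(records, demographic):
    """Count prompts per value of one demographic column (e.g. 'race')."""
    return Counter(str(record[demographic]) for record in records)


def main(path=Path("dataset") / "explicit.jsonl"):
    # Assumed location: the clone's explicit.jsonl copied into dataset/.
    if not path.is_file():
        print(f"{path} not found; is the DiscrimEval data in dataset/?")
        return
    records = load_jsonl(path)
    print(f"{len(records)} prompts in {path}")
    for demographic in ("race", "gender", "age"):
        print(demographic, dict(group_counts(records, demographic)))


if __name__ == "__main__":
    main()
```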

#### 2.2 Download HH-RLHF Dataset

Download the HH-RLHF training dataset:

```bash
# Create dataset directory if it doesn't exist
mkdir -p dataset

# Download the HH-RLHF training data
wget https://github.com/anthropics/hh-rlhf/raw/refs/heads/master/harmless-base/train.jsonl.gz

# Extract (this produces train.jsonl) and move it into dataset/
gunzip train.jsonl.gz
mv train.jsonl dataset/train.jsonl
```
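Each line of the public HH-RLHF file is a JSON object with a `chosen` and a `rejected` transcript, written as alternating `\n\nHuman:` and `\n\nAssistant:` turns. The sketch below validates that structure and extracts the final assistant reply. The helper names are hypothetical, and this README does not show how the training scripts themselves consume the pairs.

```python
import json
from pathlib import Path

ASSISTANT_MARKER = "\n\nAssistant:"


def load_preference_pairs(path):
    """Load (chosen, rejected) transcript pairs from an HH-RLHF JSONL file."""
    pairs = []
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            row = json.loads(line)
            missing = {"chosen", "rejected"} - set(row)
            if missing:
                raise ValueError(f"line {line_no}: missing keys {sorted(missing)}")
            pairs.append((row["chosen"], row["rejected"]))
    return pairs


def last_assistant_turn(transcript):
    """Return the text of the final assistant turn in a transcript."""
    idx = transcript.rfind(ASSISTANT_MARKER)
    if idx == -1:
        return transcript.strip()
    return transcript[idx + len(ASSISTANT_MARKER):].strip()


if __name__ == "__main__":
    path = Path("dataset") / "train.jsonl"
    if not path.is_file():
        print(f"{path} not found")
    else:
        pairs = load_preference_pairs(path)
        print(f"{len(pairs)} preference pairs loaded")
        if pairs:
            print("First chosen reply:", last_assistant_turn(pairs[0][0])[:200])
```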

### 3. Model Access Setup

#### 3.1 Request Llama-3-8B Access (not required if you load a model that is not gated)

1. Visit the Llama-3-8B model page: https://huggingface.co/meta-llama/Meta-Llama-3-8B
2. Click "Request access" and fill out the form
3. Wait for approval from Meta (usually takes a few hours to a few days) 

#### 3.2 Hugging Face Authentication

After getting access approval, authenticate with Hugging Face:

```bash
# Install huggingface_hub (it provides the huggingface-cli command) if not already installed
pip install huggingface_hub

# Login to Hugging Face (you'll be prompted for your token)
huggingface-cli login

# Or login with token directly
huggingface-cli login --token YOUR_HF_TOKEN_HERE --add-to-git-credential
```

To get your Hugging Face token:
1. Go to https://huggingface.co/settings/tokens
2. Create a new token with "Read" permissions
3. Copy the token for authentication

## 📝 Running the Experiments

The repository contains four main training scripts that can be run with different demographic attributes:

### Available Scripts

1. **`discrim_eval_separate_sft.py`** - Supervised Fine-Tuning baseline
2. **`discrim_eval_separate_causal_ppo.py`** - Causal PPO with backdoor adjustment
3. **`discrim_eval_separate_ppo.py`** - Simple PPO without causal adjustment
4. **`discrim_eval_separate_ppo_adversarial.py`** - Adversarial PPO training

### Available Demographics

- `race` - Racial/ethnic bias evaluation
- `age` - Age-based bias evaluation  
- `gender` - Gender-based bias evaluation

### Running the Scripts

Execute each script with the desired demographic:

```bash
# Supervised Fine-Tuning
python3 discrim_eval_separate_sft.py --demographic race
python3 discrim_eval_separate_sft.py --demographic age
python3 discrim_eval_separate_sft.py --demographic gender

# Causal PPO (our main method)
python3 discrim_eval_separate_causal_ppo.py --demographic race
python3 discrim_eval_separate_causal_ppo.py --demographic age
python3 discrim_eval_separate_causal_ppo.py --demographic gender

# Simple PPO baseline
python3 discrim_eval_separate_ppo.py --demographic race
python3 discrim_eval_separate_ppo.py --demographic age
python3 discrim_eval_separate_ppo.py --demographic gender

# Adversarial PPO
python3 discrim_eval_separate_ppo_adversarial.py --demographic race
python3 discrim_eval_separate_ppo_adversarial.py --demographic age
python3 discrim_eval_separate_ppo_adversarial.py --demographic gender
```
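To run all twelve combinations in sequence, a small Python driver can replace the shell commands. This is a sketch, not a script shipped with the repository. The script names come from the Available Scripts list. Each run uses the current interpreter, and the driver stops at the first run that exits with a non-zero status.

```python
import subprocess
import sys

# Script names as listed under "Available Scripts".
SCRIPTS = (
    "discrim_eval_separate_sft.py",
    "discrim_eval_separate_causal_ppo.py",
    "discrim_eval_separate_ppo.py",
    "discrim_eval_separate_ppo_adversarial.py",
)
DEMOGRAPHICS = ("race", "age", "gender")


def build_commands(scripts=SCRIPTS, demographics=DEMOGRAPHICS):
    """Return one command line per (script, demographic) pair."""
    return [
        [sys.executable, script, "--demographic", demographic]
        for script in scripts
        for demographic in demographics
    ]


def run_all(commands):
    """Run commands one after another; raise on the first failure."""
    for cmd in commands:
        print("Running:", " ".join(cmd), flush=True)
        subprocess.run(cmd, check=True)
```

Calling `run_all(build_commands())` from the repository root runs every combination. Pass a subset of scripts or demographics to `build_commands` to run fewer.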

### Example Usage

```bash
# Run causal PPO training for race demographic
python3 discrim_eval_separate_causal_ppo.py --demographic race

# This will:
# 1. Load and preprocess the datasets
# 2. Train a confounder predictor
# 3. Train a demographic-aware reward model
# 4. Perform PPO training with causal adjustment
# 5. Evaluate on DiscrimEval benchmark
```
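This README does not show how step 4 applies the backdoor adjustment. As an illustration only, the standard adjustment formula P(y | do(x)) = Σ_z P(y | x, z) P(z) can be applied to a scalar reward. First score a response under each value z of the confounder, for example with a demographic-aware reward model. Then average those scores with the marginal P(z). The function name and the dict-based interface are hypothetical and are not the repository's implementation.

```python
from typing import Mapping


def backdoor_adjusted_reward(reward_given_z: Mapping[str, float],
                             p_z: Mapping[str, float]) -> float:
    """Return sum_z r(x, y, z) * P(z), the backdoor-adjusted expected reward.

    reward_given_z: reward for the same (prompt, response) scored under each
        confounder value z (hypothetical interface).
    p_z: marginal distribution of the confounder; must sum to 1.
    """
    total = sum(p_z.values())
    if abs(total - 1.0) > 1e-6:
        raise ValueError(f"P(z) must sum to 1, got {total}")
    missing = set(p_z) - set(reward_given_z)
    if missing:
        raise KeyError(f"no reward for confounder values: {sorted(missing)}")
    return sum(reward_given_z[z] * p for z, p in p_z.items())
```

The adjusted quantity weights by the marginal P(z), not by the conditional P(z | x) that the data would naturally supply. This weighting is what distinguishes it from a plain conditional expectation of the reward.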

## 🔧 Configuration

### Hardware Requirements

- **GPU**: CUDA-compatible GPU with at least 24GB VRAM (recommended: H100, A100, or 32GB V100)
- **RAM**: At least 32GB system RAM
- **Storage**: At least 50GB free space for datasets and model checkpoints

### Training Parameters

Key hyperparameters can be modified in the configuration sections of each script:

- `batch_size`: Training batch size (default: 4)
- `learning_rate`: Learning rate (default: 5e-5)
- `num_epochs`: Number of training epochs (default: 4)
- `max_length`: Maximum sequence length (default: 512)
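The scripts may store these values as plain variables. One way to keep the documented defaults together and reject invalid overrides is a dataclass. The sketch below is hypothetical: `TrainingConfig` is not a class in the repository, and only the field names and defaults come from the list above.

```python
from dataclasses import asdict, dataclass


@dataclass
class TrainingConfig:
    """Hypothetical container for the documented hyperparameter defaults."""
    batch_size: int = 4
    learning_rate: float = 5e-5
    num_epochs: int = 4
    max_length: int = 512

    def __post_init__(self):
        if self.batch_size < 1 or self.num_epochs < 1 or self.max_length < 1:
            raise ValueError("batch_size, num_epochs and max_length must be positive")
        if self.learning_rate <= 0:
            raise ValueError("learning_rate must be positive")


if __name__ == "__main__":
    # Example override: a smaller batch, e.g. to reduce GPU memory use.
    print(asdict(TrainingConfig(batch_size=2)))
```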

## 📊 Output and Results

Each script prints detailed metrics, including:

- Overall accuracy on DiscrimEval
- Per-demographic group performance
- Fairness metrics (accuracy gaps between demographic groups)
- Bias mitigation effectiveness
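The following sketch shows how overall accuracy, per-group accuracy, and an accuracy gap relate. It uses one common definition of the gap: the best group's accuracy minus the worst group's. The scripts' exact metric may differ, and the function names are hypothetical.

```python
from collections import defaultdict


def overall_accuracy(predictions, labels):
    """Fraction of predictions that match their labels."""
    if len(predictions) != len(labels) or not labels:
        raise ValueError("need equal-length, non-empty predictions and labels")
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


def group_accuracies(predictions, labels, groups):
    """Accuracy computed separately for each demographic group value."""
    if not (len(predictions) == len(labels) == len(groups)):
        raise ValueError("predictions, labels and groups must have equal length")
    correct, total = defaultdict(int), defaultdict(int)
    for pred, label, group in zip(predictions, labels, groups):
        total[group] += 1
        correct[group] += int(pred == label)
    return {group: correct[group] / total[group] for group in total}


def accuracy_gap(per_group):
    """Best-group minus worst-group accuracy; 0.0 means parity across groups."""
    values = list(per_group.values())
    return max(values) - min(values)
```

For example, predictions `[1, 1, 0, 0]` against labels `[1, 0, 0, 0]` with groups `["a", "a", "b", "b"]` give an overall accuracy of 0.75, per-group accuracies of 0.5 and 1.0, and a gap of 0.5.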


## 🏗️ Project Structure

```
.
├── README.md                                  # This file
├── requirements.txt                           # Python dependencies
├── dataset/                                   # Datasets directory
│   ├── train.jsonl                            # HH-RLHF training data
│   └── [discrim-eval]                         # DiscrimEval evaluation data
├── discrim_eval_separate_sft.py               # SFT baseline
├── discrim_eval_separate_causal_ppo.py        # Causal PPO (main method)
├── discrim_eval_separate_ppo.py               # Simple PPO baseline
└── discrim_eval_separate_ppo_adversarial.py   # Adversarial PPO
```

## 🐛 Troubleshooting

### Common Issues

1. **CUDA Out of Memory**
   - Reduce `batch_size` (or `max_length`) in the script configuration

2. **Hugging Face Authentication Error**
   - Ensure you have requested and received access to Llama-3-8B
   - Verify your token has the correct permissions
   - Try logging out and logging back in: `huggingface-cli logout && huggingface-cli login`

3. **Dataset Loading Error**
   - Verify dataset files are in the correct `dataset/` directory
   - Check file permissions and ensure files are not corrupted
   - For large files, ensure sufficient disk space

4. **Import Errors**
   - Ensure all requirements are installed: `pip install -r requirements.txt`
   - Check Python version compatibility (requires Python 3.8+)
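Issues 3 and 4 can often be caught before a long run starts. The hypothetical preflight helper below checks the Python version against the stated 3.8 minimum and confirms that the expected dataset files exist and are not empty. By default it checks only `train.jsonl`. Add DiscrimEval file names (for example `explicit.jsonl`, if your copy contains it) to `required` as needed.

```python
import sys
from pathlib import Path

MIN_PYTHON = (3, 8)


def preflight(dataset_dir="dataset", required=("train.jsonl",)):
    """Return a list of problems found; an empty list means all checks passed."""
    problems = []
    if sys.version_info < MIN_PYTHON:
        problems.append(
            f"Python {MIN_PYTHON[0]}.{MIN_PYTHON[1]}+ required, "
            f"found {sys.version.split()[0]}"
        )
    root = Path(dataset_dir)
    if not root.is_dir():
        problems.append(f"missing directory: {root}")
        return problems
    for name in required:
        path = root / name
        if not path.is_file():
            problems.append(f"missing file: {path}")
        elif path.stat().st_size == 0:
            problems.append(f"empty file: {path}")
    return problems


if __name__ == "__main__":
    issues = preflight()
    print("\n".join(issues) if issues else "All checks passed")
```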


## 📚 Dependencies

Key dependencies (see `requirements.txt` for full list):

- **torch**: PyTorch framework
- **transformers**: Hugging Face transformers library
- **trl**: Transformer Reinforcement Learning library
- **peft**: Parameter Efficient Fine-Tuning
- **accelerate**: Distributed training support
- **bitsandbytes**: 8-bit optimizers and low-bit quantization
- **scikit-learn**: Machine learning utilities
- **evaluate**: Model evaluation metrics
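To see which of these packages are installed, and at which versions, a standard-library check is enough. Note that it uses distribution names, so it looks up `scikit-learn` even though the import name is `sklearn`. This sketch only reports what it finds; it does not check the pins in `requirements.txt`.

```python
from importlib.metadata import PackageNotFoundError, version

PACKAGES = ("torch", "transformers", "trl", "peft", "accelerate",
            "bitsandbytes", "scikit-learn", "evaluate")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:<14} {ver or 'NOT INSTALLED'}")
```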


